import numpy as np
import pandas as pd
import numba
from scipy.stats import norm

from cat_rfs.code.utils.tools import list_iterate


@list_iterate('date_list', chunksize=1, output='pd')
def simulate(df, n, var='return', lambda_='actual', ew=False, include_skew=False,
             randomize_perils=False, equalize_el=False, date_list=None, issue_level_corr=None):
    """
    Estimate beta (and optionally skewness) of cat bonds from returns or losses simulated from risk-model outputs.

    Parameters
    ----------
    df: pandas.DataFrame
        Panel data of cat bonds with CUSIP9-date as primary key.
    n: positive integer
        Number of simulation trials.
    var: {'return', 'loss'}, default 'return'
        If 'return', calculate betas of simulated bond returns with respect to the simulated market return.
        If 'loss', calculate betas of simulated bond losses with respect to the simulated market return. This does
        not require yield data and in most cases gives very similar results, because expected losses dominate the
        beta calculation.
    lambda_: 'actual' or float, default 'actual'
        Shape of the loss function in the domain of partial losses. If 'actual', the value is calculated for each
        bond separately so that the approximated loss distribution has the same expected loss as the observed one.
        If a float, this value is used for all bonds (for example, lambda_=1 gives a linear loss function).
    ew: boolean, default False
        If True, use an equally weighted market portfolio; otherwise weight by face value ('size').
    include_skew: boolean, default False
        If True, include the skewness of simulated bond returns in the output.
    randomize_perils: boolean, default False
        If True, randomly reassign each bond to a hypothetical peril category (typically for placebo tests).
        The number of categories equals that of the actual sample, and every category is equally likely.
    equalize_el: False or float, default False
        If set to a value f in [0, 100], set attachment probabilities, expected losses and exhaustion probabilities
        of all bonds equal to f when simulating losses.
    date_list: list
        Dates for which betas are estimated. Despite the None default this argument is required: an
        AssertionError is raised when it is None. NOTE: an empty list applies no date filter.
    issue_level_corr: float in [0, 1], optional
        If given, disasters are perfectly correlated only at tranche level, and bonds in the same peril category
        otherwise have correlation equal to this value. Useful for testing the robustness of the beta estimates.

    Returns
    -------
    pandas.DataFrame
        One row per CUSIP9-date with columns 'beta' (OLS slope of the simulated bond return on the simulated
        market return), 'var_m' (variance of the simulated market return on that date), 'var' (variance of the
        simulated bond return) and, when include_skew is True, 'skew'.
    """
    assert issue_level_corr is None or 0 <= issue_level_corr <= 1, 'issue_level_corr must be between 0 and 1.'
    assert date_list is not None, 'date_list is actually needed.'
    if date_list:
        df = df[df['date'].isin(date_list)].copy()

    keys = ['CUSIP9', 'date']

    # %% Simulate: bond_sims is at date-CUSIP9-trial level, mkt_sims holds market-level data per date-trial.
    bond_sims, mkt_sims = _simulate(df, n, var=var, lambda_=lambda_, ew=ew, randomize_perils=randomize_perils,
                                    equalize_el=equalize_el, include_P=False, issue_level_corr=issue_level_corr)
    bond_sims = bond_sims.merge(mkt_sims[['date', 'trial', 're_m']], on=['date', 'trial'])
    grouped = bond_sims.groupby(keys)

    # %% Per bond and date: regress simulated bond returns on simulated market returns.
    result = grouped[['re', 're_m']].apply(lambda g: _beta(g.values)).reset_index()

    market_var = mkt_sims.groupby('date')['re_m'].var().reset_index().rename(columns={'re_m': 'var_m'})
    result = result.merge(market_var, on='date', how='left')

    bond_var = grouped['re'].var().reset_index().rename(columns={'re': 'var'})
    result = result.merge(bond_var, on=keys, validate='1:1', how='left')

    if include_skew:
        bond_skew = grouped[['re']].skew().reset_index()
        result = result.merge(bond_skew, on=keys, how='outer', validate='1:1')

    # The beta column comes out of reset_index() unnamed (0); the skew column still carries the name 're'.
    return result.rename(columns={0: 'beta', 're': 'skew'})


@list_iterate('date_list', chunksize=1, output='pd')
def simulate_mkt(df, n, var='return', lambda_='actual', ew=False, randomize_perils=False, date_list=None):
    """
    Estimate cat bond market volatility by simulating returns or losses based on risk-model outputs.

    Parameters
    ----------
    df: pandas.DataFrame
        Panel data of cat bonds with CUSIP9-date as primary key.
    n: positive integer
        Number of simulation trials.
    var: {'return', 'loss'}, default 'return'
        Whether to simulate bond returns or bond losses.
    lambda_: 'actual' or float, default 'actual'
        Shape parameter of the loss function in the domain of partial losses.
    ew: boolean, default False
        If True, use an equally weighted market portfolio; otherwise weight by face value ('size').
    randomize_perils: boolean, default False
        If True, randomly reassign each bond to a peril category (placebo tests).
    date_list: list
        Dates to simulate. Required despite the None default (an AssertionError is raised when it is None);
        an empty list applies no date filter.

    Returns
    -------
    pandas.DataFrame
        One row per date with column 'var_m': the variance of the simulated market return across trials.
    """
    assert date_list is not None, 'date_list is actually needed.'
    if date_list:
        df = df[df['date'].isin(date_list)].copy()

    # Only the market-level frame is used here, and only its 're_m' column. include_P=False skips aggregating
    # the market price P, which this function never uses. The simulated re_m and the random draws are unaffected.
    _, mkt = _simulate(df, n, var=var, lambda_=lambda_, ew=ew, randomize_perils=randomize_perils,
                       include_P=False)

    return mkt.groupby('date')['re_m'].var().reset_index().rename(columns={'re_m': 'var_m'})


def _peril_col(prefix, peril):
    """Return the per-peril column name, e.g. ('ap', 'US Wind') -> 'ap_us_wind'."""
    return f"{prefix}_{peril.replace(' ', '_').lower()}"


def _simulate(df, n, var='return', lambda_='actual', ew=False, randomize_perils=False, equalize_el=False,
              include_P=True, issue_level_corr=None):
    """
    Create simulated return data frames at security level (security-date-trial) and market level (date-trial).
    Return as a tuple of data frames.

    Parameters
    ----------
    df: pandas.DataFrame
        Bond panel. Must contain 'CUSIP9', 'date', 'perils', 'rf', 'attachment_prob', 'expected_loss',
        'exhaustion_prob' and, for every single-peril category found in 'perils', the columns 'ap_<peril>' and
        'ep_<peril>' (peril name lower-cased, spaces replaced by underscores). Also needs 'yield' when
        var='return', 'size' when ew is False, and 'issue_num' and 'peril_group' when issue_level_corr is given.
        The input frame is not modified.
    n: int
        Number of simulation trials.
    var: {'return', 'loss'}
        'return': re = P * (1 + yield) - 1 - rf / 100. 'loss': re = P - 1.
    lambda_: 'actual' or float > 0
        Curvature of the price function between the exhaustion and attachment points. 'actual' derives it per
        bond from attachment_prob, expected_loss and exhaustion_prob.
    ew: bool
        Equally weighted market if True, otherwise weighted by 'size'.
    randomize_perils: bool
        Randomly reassign every bond to one single-peril category.
    equalize_el: False or float in [0, 100]
        If set, give every bond this attachment probability, expected loss and exhaustion probability.
    include_P: bool, default True
        If True, the market frame also carries the market average price 'P' and 'P' is dropped from the security
        frame. Otherwise 'P' stays in the security frame.
    issue_level_corr: float in [0, 1], optional
        If given, bonds in the same perils/CUSIP6/issue_num/peril_group group share a draw per peril, and draws
        of different groups have this correlation (normal draws mapped to uniforms with norm.cdf).

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        Security-level frame with columns 'CUSIP9', 'trial', 'date', 're' (plus 'P' when include_P is False),
        and market-level frame with columns 'date', 'trial', 're_m' (plus 'P' when include_P is True).
    """
    assert var in ('return', 'loss'), "var must be 'return' or 'loss'."

    # Single-peril categories, i.e. values of 'perils' that do not list several perils.
    perils = list(df.loc[~df['perils'].str.contains(','), 'perils'].unique())

    # %% Make sure required columns are present ('CUSIP6' is derived below rather than read from df)
    req_cols = ['CUSIP9', 'date', 'perils', 'rf', 'attachment_prob', 'expected_loss', 'exhaustion_prob']
    if var == 'return':
        req_cols += ['yield']
    if not ew:
        req_cols += ['size']
    if issue_level_corr is not None:
        req_cols += ['issue_num', 'peril_group']
    prob_cols = [_peril_col('ep', c) for c in perils] + [_peril_col('ap', c) for c in perils]
    req_cols += prob_cols
    missing = [c for c in req_cols if c not in df.columns]
    assert not missing, f'Missing required columns: {missing}'

    # Work on a copy so the caller's frame is never modified. The probabilities are cast to float so the float
    # assignments below never target an integer column.
    df = df[req_cols].copy().astype({c: float for c in prob_cols})
    if issue_level_corr is not None:
        df['peril_group'] = df['peril_group'].fillna(0)
        df['CUSIP6'] = df['CUSIP9'].str[:6]
    assert not df.isnull().any().any(), 'Required columns contain missing values.'

    # %% If true, assign bonds to random peril category
    if randomize_perils:
        df['perils'] = np.random.choice(perils, df.shape[0])
        # Wipe every bond's original exposure: both exhaustion ('ep_') and attachment ('ap_') probabilities.
        df[prob_cols] = 0.0
        perils = list(df['perils'].unique())
        for peril in perils:
            assigned = df['perils'] == peril
            df.loc[assigned, _peril_col('ep', peril)] = df['exhaustion_prob']
            df.loc[assigned, _peril_col('ap', peril)] = df['attachment_prob']

    # %% If set, give every bond the same attachment/exhaustion probability and expected loss equalize_el (in %)
    if equalize_el is not False:
        assert 0 <= equalize_el <= 100
        # Perils listed for each bond, e.g. 'A, B' -> {'A', 'B'}.
        peril_sets = [{s.strip() for s in x.split(',')} for x in df['perils']]
        matches = pd.Series([sum(p in s for p in perils) for s in peril_sets], index=df.index)
        for p in perils:
            # Exact membership (not regex substring matching), so the mask always agrees with `matches`
            # and never selects a row with zero matches.
            exposed = pd.Series([p in s for s in peril_sets], index=df.index)
            # Per-peril probability q (in %) with 1 - (1 - q)^m = equalize_el for a bond matched to m perils.
            per_peril = (1 - (1 - equalize_el / 100) ** (1 / matches[exposed])) * 100
            for v in ('ap', 'ep'):
                df.loc[exposed, _peril_col(v, p)] = per_peril
        df['attachment_prob'] = equalize_el
        df['expected_loss'] = equalize_el
        df['exhaustion_prob'] = equalize_el

    # %% Calculate lambda parameter that controls the curvature of loss function in the domain of partial losses
    if lambda_ == 'actual':
        df['lambda_'] = ((df['attachment_prob'] - df['exhaustion_prob'])
                         / (df['attachment_prob'] - df['expected_loss'])) - 1
    else:
        assert lambda_ > 0
        df['lambda_'] = lambda_

    # %% Draw random disaster outcomes in each trial. These are uniformly distributed and lower number indicates
    # more severe disasters.
    dmg_cols = [_peril_col('dmg', c) for c in perils]
    if issue_level_corr is not None:
        df['runvar'] = df.groupby(['perils', 'CUSIP6', 'issue_num', 'peril_group']).ngroup()
        df = df.drop(columns=['CUSIP6', 'perils', 'issue_num', 'peril_group'])
        # ngroup() labels run 0..K-1, so K groups (not max() == K-1) need a draw.
        n_groups = df['runvar'].nunique()
        corrmat = np.full((n_groups, n_groups), issue_level_corr, dtype=float)
        np.fill_diagonal(corrmat, 1.0)
        # Row r of df_rand is trial r // n_groups, group r % n_groups: matches a C-order ravel of an
        # (n, n_groups) array of draws.
        trial_idx, runvar_idx = np.meshgrid(np.arange(n), np.arange(n_groups), indexing='ij')
        df_rand = pd.DataFrame({'trial': trial_idx.ravel(), 'runvar': runvar_idx.ravel()})
        for col in dmg_cols:
            draws = np.random.multivariate_normal(np.zeros(n_groups), corrmat, n)  # shape (n, n_groups)
            df_rand[col] = norm.cdf(draws).ravel()
        df = df.merge(df_rand, on='runvar').drop(columns=['runvar'])
    else:
        df_rand = pd.DataFrame(np.random.rand(n, len(perils)), columns=dmg_cols)
        df_rand.insert(0, 'trial', np.arange(n))
        # Every bond-date row is paired with every trial.
        df = df.drop(columns=['perils']).merge(df_rand, how='cross')

    del df_rand

    # %% Calculate bond price P in a given disaster scenario.
    price_cols = []
    for i, peril in enumerate(perils):
        dmg = df[_peril_col('dmg', peril)] * 100
        ap = df[_peril_col('ap', peril)]
        ep = df[_peril_col('ep', peril)]
        # Partial-loss region: price goes from 0 at the exhaustion point to 1 at the attachment point.
        price = ((dmg - ep) / (ap - ep)) ** df['lambda_']
        # Full loss at/below exhaustion, no loss at/above attachment (the latter takes precedence).
        # Note: This also takes care of binary bonds.
        price = price.mask(dmg <= ep, 0.0).mask(dmg >= ap, 1.0)
        df[f'P{i}'] = price
        price_cols.append(f'P{i}')
    # The bond's price in a trial is the lowest price across its perils.
    df['P'] = df[price_cols].min(axis='columns')

    cols = ['CUSIP9', 'P', 'trial', 'date']
    if var == 'return':
        cols += ['yield', 'rf']
    if not ew:
        cols += ['size']
    df = df[cols].copy()

    # %% Calculate bond return
    if var == 'return':
        df['re'] = df['P'] * (1 + df['yield']) - 1 - df['rf'] / 100
        df = df.drop(columns=['yield', 'rf'])
    else:  # var == 'loss'
        df['re'] = df['P'] - 1

    # %% Calculate market portfolio return.
    if not ew:
        df['re_m'] = df['re'] * df['size']
        sum_cols = ['re_m', 'size']
        if include_P:
            df['P'] = df['P'] * df['size']
            sum_cols.append('P')

        df2 = df.groupby(['date', 'trial'], as_index=False)[sum_cols].sum()
        df = df.drop(columns=sum_cols)
        df2['re_m'] = df2['re_m'] / df2['size']
        if include_P:
            df2['P'] = df2['P'] / df2['size']
        df2 = df2.drop(columns=['size'])
    else:
        # A list (not a tuple) of columns: tuple selection on a groupby was removed in pandas 2.0.
        mean_cols = ['re', 'P'] if include_P else ['re']
        df2 = df.groupby(['date', 'trial'], as_index=False)[mean_cols].mean().rename(columns={'re': 're_m'})
        if include_P:
            df = df.drop(columns=['P'])

    return df, df2


@numba.jit
def _beta(array):
    """
    Calculate beta using nx2 array where first column is y-variable and second x-variable.

    Fits y on an intercept plus the regressor column(s) by ordinary least squares (normal equations) and
    returns the coefficient on the second column. Columns beyond the second, if present, enter as additional
    regressors because the design matrix takes ``array[:, 1:]``.

    Parameters
    ----------
    array: numpy.ndarray
        2-D array of shape (n, k), k >= 2. Column 0 is the dependent variable; columns 1..k-1 are regressors.

    Returns
    -------
    float
        OLS coefficient on column 1 (coefficient 0 is the intercept).
    """
    # Design matrix: a column of ones (intercept) followed by the regressor column(s).
    X = np.append(np.ones((array.shape[0], 1)), array[:, 1:], axis=1)
    # (X'X)^-1 X'y, then pick element 1 (the slope on the first regressor).
    # NOTE(review): X'X is singular when the regressor is constant across rows (e.g. a simulated market return
    # that never varies), in which case the inverse fails -- confirm callers cannot hit this.
    return (np.linalg.inv(X.T @ X) @ X.T @ array[:, :1])[1].item()
